from typing import Optional
from torch import optim
import torch
import copy
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
# Large finite magnitude; attention() fills masked score positions with
# -MY_INF before the softmax.
MY_INF = 1e12

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# The chosen device is printed once at import time.
Device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(Device)

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    A table of shape ``[1, max_seq_len, d_model]`` is built once in the
    constructor.  For position ``pos`` and even channel index ``2i`` it holds
    ``sin(pos / 10000**(2i / d_model))`` in channel ``2i`` and the matching
    cosine in channel ``2i + 1``.  The table is registered as a
    non-persistent buffer, so it follows the module across devices but is
    not written to the state dict.
    """

    def __init__(self, d_model: int, max_seq_len: int):
        """Precompute the encoding table.

        Args:
            d_model: Width of the embeddings; must be even.
            max_seq_len: Longest sequence length ``forward`` will accept.

        Raises:
            AssertionError: If ``d_model`` is odd.
        """
        super().__init__()

        # Sine and cosine channels are interleaved pairwise, so the model
        # width has to be even.
        assert d_model % 2 == 0

        # Column vector of positions 0 .. max_seq_len - 1 and row vector of
        # even channel indices 0, 2, .., d_model - 2.  Broadcasting the two
        # yields the same [max_seq_len, d_model // 2] grid torch.meshgrid
        # would, without depending on meshgrid's default indexing mode
        # (calling it without ``indexing`` warns on recent PyTorch).
        pos = torch.linspace(0, max_seq_len - 1, max_seq_len).unsqueeze(1)
        two_i = torch.linspace(0, d_model - 2, d_model // 2).unsqueeze(0)
        angle = pos / 10000**(two_i / d_model)

        # Stack on a new trailing axis and flatten it so that sine and
        # cosine alternate along the channel dimension.
        pe = torch.stack((torch.sin(angle), torch.cos(angle)), 2)
        pe = pe.reshape(1, max_seq_len, d_model)

        # Third argument is ``persistent``: keep the table out of state_dict.
        self.register_buffer('pe', pe, False)

    def forward(self, x: torch.Tensor):
        """Scale ``x`` by ``sqrt(d_model)`` and add the position table.

        Args:
            x: Tensor of shape ``[n, seq_len, d_model]``.

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            AssertionError: If ``seq_len`` exceeds the table length or the
                last dimension of ``x`` differs from the table width.
        """
        n, seq_len, d_model = x.shape
        pe: torch.Tensor = self.pe
        assert seq_len <= pe.shape[1]
        assert d_model == pe.shape[2]
        rescaled_x = x * d_model**0.5
        # Only the first seq_len rows of the table are needed; the leading
        # singleton dimension broadcasts over the batch.
        return rescaled_x + pe[:, 0:seq_len, :]


def attention(q: torch.Tensor,
              k: torch.Tensor,
              v: torch.Tensor,
              mask: Optional[torch.Tensor] = None):
    '''
    Scaled dot-product attention.

    Expected shapes:
        q: [n, heads, q_len, d_k]
        k: [n, heads, k_len, d_k]
        v: [n, heads, k_len, d_v]

    mask, when given, must have dtype bool and be broadcastable to
    [n, heads, q_len, k_len]; positions where it is True are filled with
    -MY_INF before the softmax.

    Returns a tensor of shape [n, heads, q_len, d_v].
    '''
    assert q.shape[-1] == k.shape[-1]
    scale = k.shape[-1]**0.5

    # Similarity of every query with every key: [n, heads, q_len, k_len]
    scores = torch.matmul(q, k.transpose(-2, -1)) / scale
    if mask is not None:
        scores.masked_fill_(mask, -MY_INF)

    # Normalise over the key axis, then mix the values with those weights.
    weights = F.softmax(scores, -1)
    return torch.matmul(weights, v)


class MultiHeadAttention(nn.Module):
    """Multi-head attention with learned q/k/v/output projections.

    The model width is split evenly across ``heads``; keys and values share
    the same per-head width ``d_k = d_model // heads``.
    """

    def __init__(self, heads: int, d_model: int, dropout: float = 0.1):
        """Create the four projections and the dropout layer.

        Raises:
            AssertionError: If ``d_model`` is not divisible by ``heads``.
        """
        super().__init__()

        assert d_model % heads == 0
        self.heads = heads
        self.d_model = d_model
        # Per-head width, used for both keys and values.
        self.d_k = d_model // heads

        self.q = nn.Linear(d_model, d_model)
        self.k = nn.Linear(d_model, d_model)
        self.v = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x: torch.Tensor, n: int, length: int):
        """Reshape [n, length, d_model] into [n, heads, length, d_k]."""
        return x.reshape(n, length, self.heads, self.d_k).transpose(1, 2)

    def forward(self,
                q: torch.Tensor,
                k: torch.Tensor,
                v: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
        """Attend from ``q`` to ``k``/``v``.

        Args:
            q: Queries, ``[n, q_len, d_model]``.
            k: Keys, ``[n, k_len, d_model]``.
            v: Values, ``[n, k_len, d_model]``.
            mask: Optional bool mask handed unchanged to ``attention``.

        Returns:
            Tensor of shape ``[n, q_len, d_model]``.

        Raises:
            AssertionError: If batch sizes differ or ``k`` and ``v`` have
                different sequence lengths.
        """
        # All three inputs must share the batch dimension ...
        assert q.shape[0] == k.shape[0]
        assert q.shape[0] == v.shape[0]
        # ... and keys and values must cover the same positions.
        assert k.shape[1] == v.shape[1]

        n, q_len = q.shape[:2]
        k_len = k.shape[1]

        heads_q = self._split_heads(self.q(q), n, q_len)
        heads_k = self._split_heads(self.k(k), n, k_len)
        heads_v = self._split_heads(self.v(v), n, k_len)

        per_head = attention(heads_q, heads_k, heads_v, mask)

        # Move the head axis back next to d_k and merge the two.
        merged = per_head.transpose(1, 2).reshape(n, q_len, self.d_model)
        return self.out(self.dropout(merged))


class FeedForward(nn.Module):
    """Position-wise two-layer MLP: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        """Build the expansion (d_model -> d_ff) and projection layers."""
        super().__init__()
        self.layer1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.layer2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        """Apply the MLP over the last dimension of ``x``."""
        hidden = F.relu(self.layer1(x))
        return self.layer2(self.dropout(hidden))


class EncoderLayer(nn.Module):
    """One encoder block: self-attention then feed-forward.

    Each sub-layer output passes through dropout, is added back to its
    input, and the sum is layer-normalised.
    """

    def __init__(self,
                 heads: int,
                 d_model: int,
                 d_ff: int,
                 dropout: float = 0.1):
        """Create the two sub-layers with their norms and dropouts."""
        super().__init__()
        self.self_attention = MultiHeadAttention(heads, d_model, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, src_mask: Optional[torch.Tensor] = None):
        """Run the block on ``x``; ``src_mask`` is passed to self-attention."""
        # Sub-layer 1: self-attention with residual connection.
        attended = self.dropout1(self.self_attention(x, x, x, src_mask))
        x = self.norm1(x + attended)

        # Sub-layer 2: feed-forward with residual connection.
        transformed = self.dropout2(self.ffn(x))
        return self.norm2(x + transformed)


class DecoderLayer(nn.Module):
    """One decoder block: self-attention, encoder attention, feed-forward.

    Each sub-layer output passes through dropout, is added back to its
    input, and the sum is layer-normalised.
    """

    def __init__(self,
                 heads: int,
                 d_model: int,
                 d_ff: int,
                 dropout: float = 0.1):
        """Create the three sub-layers with their norms and dropouts."""
        super().__init__()
        self.self_attention = MultiHeadAttention(heads, d_model, dropout)
        self.attention = MultiHeadAttention(heads, d_model, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self,
                x,
                encoder_kv: torch.Tensor,
                dst_mask: Optional[torch.Tensor] = None,
                src_dst_mask: Optional[torch.Tensor] = None):
        """Run the block.

        Args:
            x: Decoder-side activations.
            encoder_kv: Encoder output, used as both keys and values of the
                second attention.
            dst_mask: Mask for the decoder self-attention.
            src_dst_mask: Mask for the attention over ``encoder_kv``.
        """
        # Sub-layer 1: self-attention over the decoder sequence.
        self_attended = self.dropout1(self.self_attention(x, x, x, dst_mask))
        x = self.norm1(x + self_attended)

        # Sub-layer 2: queries from the decoder, keys/values from the encoder.
        cross_attended = self.dropout2(
            self.attention(x, encoder_kv, encoder_kv, src_dst_mask))
        x = self.norm2(x + cross_attended)

        # Sub-layer 3: feed-forward.
        transformed = self.dropout3(self.ffn(x))
        return self.norm3(x + transformed)


class Encoder(nn.Module):
    """Token embedding, positional encoding and a stack of EncoderLayers."""

    def __init__(self,
                 vocab_size: int,
                 pad_idx: int,
                 d_model: int,
                 d_ff: int,
                 n_layers: int,
                 heads: int,
                 dropout: float = 0.1,
                 max_seq_len: int = 120):
        """Build the embedding, the position table and ``n_layers`` blocks.

        ``pad_idx`` is handed to ``nn.Embedding`` as its padding index.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.pe = PositionalEncoding(d_model, max_seq_len)
        self.layers = nn.ModuleList([
            EncoderLayer(heads, d_model, d_ff, dropout)
            for _ in range(n_layers)
        ])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, src_mask: Optional[torch.Tensor] = None):
        """Encode token indices ``x``; ``src_mask`` goes to every layer."""
        hidden = self.dropout(self.pe(self.embedding(x)))
        for block in self.layers:
            hidden = block(hidden, src_mask)
        return hidden


class Decoder(nn.Module):
    """Token embedding, positional encoding and a stack of DecoderLayers."""

    def __init__(self,
                 vocab_size: int,
                 pad_idx: int,
                 d_model: int,
                 d_ff: int,
                 n_layers: int,
                 heads: int,
                 dropout: float = 0.1,
                 max_seq_len: int = 120):
        """Build the embedding, the position table and ``n_layers`` blocks.

        ``pad_idx`` is handed to ``nn.Embedding`` as its padding index.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model, pad_idx)
        self.pe = PositionalEncoding(d_model, max_seq_len)
        # The layers are only ever iterated and each call needs four
        # arguments, so a ModuleList (as used by Encoder) is the right
        # container; nn.Sequential's own forward could not pass them along.
        # Sub-module names ("layers.0", "layers.1", ...) are the same for
        # both containers, so existing state dicts keep loading.
        self.layers = nn.ModuleList(
            [DecoderLayer(heads, d_model, d_ff, dropout)
             for _ in range(n_layers)])
        self.dropout = nn.Dropout(dropout)

    def forward(self,
                x,
                encoder_kv,
                dst_mask: Optional[torch.Tensor] = None,
                src_dst_mask: Optional[torch.Tensor] = None):
        """Decode target token indices against the encoder output.

        Args:
            x: Target token indices, fed to the embedding.
            encoder_kv: Encoder output passed to every layer.
            dst_mask: Mask for the decoder self-attention.
            src_dst_mask: Mask for the attention over ``encoder_kv``.

        Returns:
            The output of the last decoder layer.
        """
        x = self.embedding(x)
        x = self.pe(x)
        x = self.dropout(x)
        for layer in self.layers:
            x = layer(x, encoder_kv, dst_mask, src_dst_mask)
        return x


class Transformer(nn.Module):
    """Encoder-decoder Transformer producing scores over the target vocabulary.

    Source and target sequences share one padding index; ``forward`` derives
    all attention masks from it.
    """

    def __init__(self,
                 src_vocab_size: int,
                 dst_vocab_size: int,
                 pad_idx: int,
                 d_model: int,
                 d_ff: int,
                 n_layers: int,
                 heads: int,
                 dropout: float = 0.1,
                 max_seq_len: int = 200):
        """Build the encoder, the decoder and the output projection.

        Args:
            src_vocab_size: Size of the source vocabulary.
            dst_vocab_size: Size of the target vocabulary.
            pad_idx: Token id treated as padding on both sides.
            d_model: Model width.
            d_ff: Hidden width of the feed-forward sub-layers.
            n_layers: Number of layers in each of encoder and decoder.
            heads: Number of attention heads.
            dropout: Dropout probability used throughout.
            max_seq_len: Longest sequence the positional tables support.
        """
        super().__init__()
        self.encoder = Encoder(src_vocab_size, pad_idx, d_model, d_ff,
                               n_layers, heads, dropout, max_seq_len)
        self.decoder = Decoder(dst_vocab_size, pad_idx, d_model, d_ff,
                               n_layers, heads, dropout, max_seq_len)
        self.pad_idx = pad_idx
        self.output_layer = nn.Linear(d_model, dst_vocab_size)

    def generate_mask(self,
                      q_pad: torch.Tensor,
                      k_pad: torch.Tensor,
                      with_left_mask: bool = False):
        """Build a bool attention mask where True marks a blocked position.

        A position ``(i, j)`` is blocked when query ``i`` is padding, when
        key ``j`` is padding, or - if ``with_left_mask`` is set - when
        ``j > i`` (the key lies after the query).

        Args:
            q_pad: Bool tensor ``[n, q_len]``, True at padded query positions.
            k_pad: Bool tensor ``[n, k_len]``, True at padded key positions.
            with_left_mask: Also block keys that come after the query.

        Returns:
            Bool tensor of shape ``[n, 1, q_len, k_len]`` on the device of
            the inputs; the singleton axis broadcasts over heads.

        Raises:
            AssertionError: If the two inputs live on different devices.
        """
        assert q_pad.device == k_pad.device
        n, q_len = q_pad.shape
        n, k_len = k_pad.shape

        q_pad = q_pad.to(torch.bool)
        k_pad = k_pad.to(torch.bool)

        # Broadcast a column of query flags against a row of key flags.
        # This is done directly on the input device, with no per-sample
        # Python loop and no host-to-device copy.
        mask = q_pad[:, None, :, None] | k_pad[:, None, None, :]

        if with_left_mask:
            rows = torch.arange(q_len, device=q_pad.device).unsqueeze(1)
            cols = torch.arange(k_len, device=q_pad.device).unsqueeze(0)
            # Strictly-upper-triangular part: key index after query index.
            mask = mask | (cols > rows)
        return mask

    def forward(self, x, y):
        """Compute target-vocabulary scores for every target position.

        Args:
            x: Source token indices, ``[n, src_len]``.
            y: Target-side input token indices, ``[n, dst_len]``.

        Returns:
            Tensor of shape ``[n, dst_len, dst_vocab_size]`` (unnormalised
            output of the final linear layer).
        """
        src_pad_mask = x == self.pad_idx
        dst_pad_mask = y == self.pad_idx
        src_mask = self.generate_mask(src_pad_mask, src_pad_mask, False)
        # Only the decoder self-attention hides later positions.
        dst_mask = self.generate_mask(dst_pad_mask, dst_pad_mask, True)
        src_dst_mask = self.generate_mask(dst_pad_mask, src_pad_mask, False)
        encoder_kv = self.encoder(x, src_mask)
        res = self.decoder(y, encoder_kv, dst_mask, src_dst_mask)
        res = self.output_layer(res)
        return res

# src = torch.tensor([[0, 2, 5, 6, 4, 3, 9, 5, 2, 9, 10, 1],
#                     [0, 2, 8, 7, 3, 4, 5, 6, 7, 2, 10, 1]])
# target = torch.tensor([[0, 1, 7, 4, 3, 5, 9, 2, 8, 10, 9, 1],
#                        [0, 1, 5, 6, 2, 4, 7, 6, 2, 8, 10, 1]])

# src_vocab_size: int,
# dst_vocab_size: int,
# pad_idx: int,
# d_model: int,
# d_ff: int,
# n_layers: int,
# heads: int,
# dropout: float = 0.1,
# max_seq_len: int = 200


# t = Transformer(12, 12, 11, 512, 4*512, 6, 8)
# out = t(src, target)
# print(out.size())


class SimpleDataset(Dataset):
    """Toy copy task: random length-5 sequences and their framed copies.

    Each sample is a tensor of five integers drawn from 1..8.  Its label is
    the same sequence with ``0`` prepended and ``9`` appended.
    """

    def __init__(self, num):
        """Generate ``num`` random samples and their labels."""
        prefix = torch.tensor([0])  # token inserted in front of the sequence
        suffix = torch.tensor([9])  # token appended after the sequence

        self.data = [torch.randint(1, 9, size=(5,)) for _ in range(num)]
        # torch.cat allocates a new tensor for every label, so the samples
        # can be used directly without copying them first.
        self.labels = [
            torch.cat([prefix, sample, suffix], dim=0) for sample in self.data
        ]

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair at position ``idx``."""
        return self.data[idx], self.labels[idx]


def create_data(batch_size=5, num=100):
    """Return a shuffling DataLoader over a fresh SimpleDataset.

    Args:
        batch_size: Number of samples per batch.
        num: Number of samples to generate.
    """
    return DataLoader(SimpleDataset(num), batch_size=batch_size, shuffle=True)


def train(max_decode_len: int = 50):
    """Train a Transformer on the toy copy task, then greedily decode once.

    The model is trained for 30 epochs on 1000 random samples with Adam and
    gradient-norm clipping; the average loss and learning rate are printed
    after every epoch.  Afterwards the fixed input ``[1, 2, 3, 4, 5]`` is
    decoded token by token, printing the growing output after each step.

    Args:
        max_decode_len: Upper bound on the number of decoding steps.
            Decoding stops earlier when the end token ``9`` is produced.
            The bound keeps the loop from running until the positional
            table is exhausted when the model never emits the end token;
            it must stay within the model's ``max_seq_len`` (200 here).
    """
    # Vocabulary of 11 ids on both sides, id 10 reserved as padding.
    model = Transformer(11, 11, 10, 512, 4 * 512, 6, 8)
    model.to(Device)

    intervals = 30  # number of training epochs

    # No ignore_index is set: every sample produced by create_data has the
    # same length, so batches contain no padding.
    criterion = nn.CrossEntropyLoss()
    opt = optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.98), eps=1e-9)

    datas = create_data(32, 1000)

    for interval in range(intervals):
        model.train()
        loss_all = 0
        for data, label in datas:
            data = data.to(Device)
            label = label.to(Device)

            # Teacher forcing: the decoder sees the label without its last
            # token and is trained to predict the label shifted left by one.
            input_label = label[:, :-1]
            output_label = label[:, 1:]

            opt.zero_grad()
            output = model(data, input_label)

            # Flatten batch and time so each position is one classification.
            loss = criterion(output.reshape(-1, output.size(-1)),
                             output_label.reshape(-1))
            loss_all += loss.item()

            loss.backward()
            # Clip the gradient norm to guard against exploding gradients.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            opt.step()

        avg_loss = loss_all / len(datas)
        print(f'第 {interval + 1} 轮, 平均损失: {avg_loss:.6f}, 学习率: {opt.param_groups[0]["lr"]:.8f}')

    model.eval()
    with torch.no_grad():
        # Fixed test input and a decoder sequence holding only the start
        # token, both created directly on the target device.
        test_in_data = torch.tensor([[1, 2, 3, 4, 5]],
                                    dtype=torch.long,
                                    device=Device)
        label = torch.tensor([[0]], dtype=torch.long, device=Device)

        for _ in range(max_decode_len):
            # The model builds its own padding and look-ahead masks.
            output = model(test_in_data, label)
            # Greedy choice for the newest position; the argmax of the raw
            # scores equals the argmax of their log-softmax.
            next_token = output[:, -1].argmax(dim=-1, keepdim=True)
            label = torch.cat([label, next_token], dim=1)
            print(label)
            if next_token.item() == 9:  # 9 is the end token
                break




# Run training followed by the greedy-decoding check only when this file is
# executed as a script, not when it is imported.
if __name__ == '__main__':
    train()
